"""
# Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import json
import re
from typing import Any, List, Optional

import paddle
import torch

from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
    BackendBase, BaseChecker, LogitsProcessorBase)
from fastdeploy.utils import llm_logger

try:
    from xgrammar import (CompiledGrammar, Grammar, GrammarCompiler,
                          GrammarMatcher, StructuralTagItem, TokenizerInfo,
                          allocate_token_bitmask, apply_token_bitmask_inplace)
except Exception as e:
    raise Exception(
        f"import XGrammar failed, please check your environment:\n\t {e}")


class XGrammarProcessor(LogitsProcessorBase):
    """
    XGrammar-specific implementation of LogitsProcessorBase.

    This processor enforces grammar constraints during token generation using XGrammar.
    It owns a GrammarMatcher that tracks the grammar matching state, fills token
    bitmasks from that state, and applies the bitmasks to logits.

    Attributes:
        max_rollback_tokens (int): Maximum number of tokens the matcher can roll back (fixed at 200)
        vocab_size (Optional[int]): Size of the vocabulary, used when allocating bitmasks
        batch_size (Optional[int]): Number of rows allocated in the token bitmask
        splitwise_role (str): Role for splitwise processing
        compiled_grammar (CompiledGrammar): Compiled grammar rules
        terminate_without_stop_token (bool): Whether to terminate without stop token
        override_stop_tokens (Optional[List[int]]): Custom stop tokens
        matcher (GrammarMatcher): Grammar matching engine
    """

    def __init__(
        self,
        compiled_grammar: CompiledGrammar,
        terminate_without_stop_token: bool = False,
        override_stop_tokens: Optional[List[int]] = None,
        vocab_size: Optional[int] = None,
        batch_size: Optional[int] = None,
        splitwise_role: str = "mixed",
    ):
        super().__init__()
        self.max_rollback_tokens = 200
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.splitwise_role = splitwise_role
        self.compiled_grammar = compiled_grammar
        self.terminate_without_stop_token = terminate_without_stop_token
        self.override_stop_tokens = override_stop_tokens

        self.matcher = GrammarMatcher(
            compiled_grammar=compiled_grammar,
            max_rollback_tokens=self.max_rollback_tokens,
            terminate_without_stop_token=terminate_without_stop_token,
            override_stop_tokens=override_stop_tokens,
        )

    def allocate_token_bitmask(self) -> torch.Tensor:
        """
        Allocate a token bitmask tensor for grammar constraints.

        Returns:
            torch.Tensor: The bitmask created by xgrammar's ``allocate_token_bitmask`` for
            ``batch_size`` rows and ``vocab_size`` tokens. xgrammar packs 32 tokens into
            each int32 word, so the shape is (batch_size, ceil(vocab_size / 32)), and
            every token starts out allowed.
        """
        return allocate_token_bitmask(self.batch_size, self.vocab_size)

    def fill_token_bitmask(self, token_bitmask: torch.Tensor,
                           idx: int) -> None:
        """
        Fill row ``idx`` of the token bitmask with the tokens the matcher currently allows.

        Args:
            token_bitmask (torch.Tensor): The token bitmask tensor to fill
            idx (int): The batch row to fill the mask for

        Returns:
            None: Modifies the token_bitmask in-place
        """
        self.matcher.fill_next_token_bitmask(token_bitmask, idx)

    def apply_token_mask(
        self,
        logits: paddle.Tensor,
        token_bitmask: torch.Tensor,
        indices: Optional[List[int]] = None,
    ) -> paddle.Tensor:
        """
        Apply the token mask to the logits, masking out tokens the grammar disallows.

        The logits are copied to a float32 CPU torch tensor, masked in place by
        xgrammar, and converted back to a Paddle tensor with the original dtype
        and place.

        Args:
            logits (paddle.Tensor): The logits tensor to mask
            token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
            indices (Optional[List[int]]): Optional list of batch indices to apply mask to

        Returns:
            paddle.Tensor: A new logits tensor with the mask applied
        """
        origin_place = logits.place
        origin_dtype = logits.dtype

        # NumPy has no bfloat16 type, so cast to float32 inside Paddle before
        # exporting; this also gives xgrammar the float32 CPU tensor it masks.
        cpu_logits = torch.from_numpy(logits.astype("float32").numpy())
        apply_token_bitmask_inplace(
            logits=cpu_logits,
            bitmask=token_bitmask.to(cpu_logits.device, non_blocking=True),
            indices=indices,
        )

        return paddle.to_tensor(
            cpu_logits.numpy(),
            dtype=origin_dtype,
            place=origin_place,
        )

    def reset(self) -> None:
        """
        Reset the grammar matcher state to initial conditions.

        Returns:
            None: No return value
        """
        self.matcher.reset()

    def accept_token(self, token: int) -> None:
        """
        Validate and accept a generated token against the grammar constraints.

        Args:
            token (int): The token ID to validate

        Raises:
            AssertionError: If token is not allowed by the grammar
        """
        if not self.matcher.accept_token(token):
            # Raise explicitly rather than using `assert`, so the check still
            # runs under `python -O`.
            raise AssertionError(f"Failed to accept token {token}")

    def is_terminated(self) -> bool:
        """
        Check if the grammar matching process has terminated.

        Returns:
            bool: True if matching has terminated, False otherwise
        """
        return self.matcher.is_terminated()

    def copy(self) -> "XGrammarProcessor":
        """
        Create a new processor with the same configuration as this one.

        The copy shares the compiled grammar but builds its own GrammarMatcher,
        so it starts from the initial matching state. This instance's
        accepted-token progress is NOT copied.

        Returns:
            XGrammarProcessor: A new processor with identical configuration
        """
        return XGrammarProcessor(
            compiled_grammar=self.compiled_grammar,
            terminate_without_stop_token=self.terminate_without_stop_token,
            override_stop_tokens=self.override_stop_tokens,
            vocab_size=self.vocab_size,
            batch_size=self.batch_size,
            splitwise_role=self.splitwise_role,
        )


class XGrammarBackend(BackendBase):
    """
    XGrammar-specific implementation of BackendBase.

    This backend compiles JSON schemas, regexes, EBNF grammars and structural tags
    into XGrammar processors. It owns the grammar compiler built from the tokenizer.

    Attributes:
        vocab_size (int): Size of the vocabulary from config
        batch_size (int): Maximum batch size (max_num_seqs) from config
        any_whitespace (bool): Whether to allow any whitespace in JSON
        splitwise_role (str): Role for splitwise processing
        grammar_compiler (GrammarCompiler): Grammar compilation engine
    """

    def __init__(
        self,
        fd_config: FDConfig,
        **kwargs,
    ):
        """
        Build the XGrammar compiler from the model's tokenizer.

        Args:
            fd_config (FDConfig): Global FastDeploy configuration
            **kwargs: Accepted for interface compatibility; unused here

        Raises:
            Exception: If the XGrammar tokenizer info or compiler cannot be created
        """
        super().__init__(fd_config=fd_config)
        self.vocab_size = fd_config.model_config.vocab_size
        self.batch_size = fd_config.parallel_config.max_num_seqs

        self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace
        self.splitwise_role = fd_config.parallel_config.splitwise_role

        try:
            # NOTE(review): self.hf_tokenizer is presumably set by BackendBase.__init__ — confirm.
            tokenizer_info = TokenizerInfo.from_huggingface(
                self.hf_tokenizer, vocab_size=self.vocab_size)
            self.grammar_compiler = GrammarCompiler(
                tokenizer_info=tokenizer_info)
        except Exception as e:
            # Chain the cause explicitly so the original traceback is preserved.
            raise Exception(f"Failed to load XGrammar tokenizer: {e}") from e

    def _create_processor(
        self,
        compiled_grammar: CompiledGrammar,
        terminate_without_stop_token: bool = False,
        override_stop_tokens: Optional[List[int]] = None,
    ) -> XGrammarProcessor:
        """
        Create a logits processor instance for the given compiled grammar.

        Args:
            compiled_grammar (CompiledGrammar): Compiled grammar rules
            terminate_without_stop_token (bool): Whether to terminate without stop token
            override_stop_tokens (Optional[List[int]]): Custom stop tokens to override defaults

        Returns:
            XGrammarProcessor: Processor configured with this backend's vocab size,
            batch size and splitwise role
        """
        return XGrammarProcessor(
            compiled_grammar=compiled_grammar,
            terminate_without_stop_token=terminate_without_stop_token,
            override_stop_tokens=override_stop_tokens,
            vocab_size=self.vocab_size,
            batch_size=self.batch_size,
            splitwise_role=self.splitwise_role,
        )

    def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile JSON schema into a grammar processor.

        Args:
            schemata (str): JSON schema string to compile

        Returns:
            Optional[XGrammarProcessor]: Configured processor if successful, None on failure
            (the failure is logged)
        """
        try:
            compiled_grammar = self.grammar_compiler.compile_json_schema(
                schemata, any_whitespace=self.any_whitespace)
        except Exception as e:
            llm_logger.error(f"Failed to compile json schema: {e}")
            return None
        return self._create_processor(compiled_grammar)

    def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile regex pattern into a grammar processor.

        Args:
            schemata (str): Regex pattern string to compile

        Returns:
            Optional[XGrammarProcessor]: Configured processor if successful, None on failure
            (the failure is logged)
        """
        try:
            compiled_grammar = self.grammar_compiler.compile_regex(schemata)
        except Exception as e:
            llm_logger.error(f"Failed to compile regex schema: {e}")
            return None
        return self._create_processor(compiled_grammar)

    def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile grammar (EBNF) into a grammar processor.

        Args:
            schemata (str): Grammar string in EBNF format

        Returns:
            Optional[XGrammarProcessor]: Configured processor if successful, None on failure
            (the failure is logged)
        """
        try:
            compiled_grammar = self.grammar_compiler.compile_grammar(schemata)
        except Exception as e:
            llm_logger.error(f"Failed to compile ebnf schema: {e}")
            return None
        return self._create_processor(compiled_grammar)

    def _structural_tag_processor(
            self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile structural tags into a grammar processor.

        Args:
            schemata (str): JSON string with a "structures" list (each item having
                "begin", "schema" and "end") and a "triggers" entry

        Returns:
            Optional[XGrammarProcessor]: Configured processor if successful, None on failure
            (the failure is logged)
        """
        try:
            structural_tag = json.loads(schemata)
            tags = [
                StructuralTagItem(
                    begin=structure["begin"],
                    schema=json.dumps(structure["schema"]),
                    end=structure["end"],
                ) for structure in structural_tag["structures"]
            ]

            compiled_grammar = self.grammar_compiler.compile_structural_tag(
                tags, structural_tag["triggers"])
        except Exception as e:
            llm_logger.error(f"Failed to compile structural tags schema: {e}")
            return None
        return self._create_processor(compiled_grammar)


class XGrammarChecker(BaseChecker):
    """
    XGrammar-specific implementation of BaseChecker.

    This validator checks and formats guided-decoding schemas (JSON, grammar,
    JSON object, choice, structural tag) for compatibility with XGrammar
    before processing.

    Attributes:
        any_whitespace (bool): Whether to allow any whitespace in JSON
    """

    def __init__(self, **kwargs):
        super().__init__()

        # NOTE(review): when "disable_any_whitespace" is not passed this yields
        # any_whitespace=False; confirm that matches the backend's configured value.
        self.any_whitespace = not kwargs.get("disable_any_whitespace", True)

    def _unsupported_json_schema(self, schema: dict[str, Any]) -> bool:
        """
        Check if JSON schema contains unsupported features.

        Recursively scans nested dicts and dicts inside lists. Any non-dict input
        (for example a raw JSON string) is reported as supported, so callers must
        pass the parsed schema.

        Args:
            schema (dict[str, Any]): Parsed JSON schema to validate

        Returns:
            bool: True if schema contains unsupported features, False otherwise
        """

        def check_object(obj: dict[str, Any]) -> bool:
            if not isinstance(obj, dict):
                return False

            if obj.get("type") in ("integer", "number") and ("multipleOf"
                                                             in obj):
                return True

            if obj.get("type") == "array" and any(
                    key in obj for key in ("uniqueItems", "contains",
                                           "minContains", "maxContains")):
                return True

            if obj.get("type") == "string" and "format" in obj:
                return True

            if obj.get("type") == "object" and any(
                    key in obj
                    for key in ("minProperties", "maxProperties",
                                "propertyNames", "patternProperties")):
                return True

            for value in obj.values():
                if isinstance(value, dict):
                    if check_object(value):
                        return True
                elif isinstance(value, list):
                    for item in value:
                        if isinstance(item, dict) and check_object(item):
                            return True
            return False

        return check_object(schema)

    def schema_format(self, request: Request):
        """
        Validate the request's guided-decoding schema and normalize it to the
        form the XGrammar backend expects.

        Only the first non-empty field is handled, checked in this order:
        guided_json, guided_grammar, guided_json_object, guided_choice,
        structural_tag. The request may be modified in place, e.g. guided_json
        is stored as a string and guided_choice is converted into an EBNF
        guided_grammar.

        Args:
            request (Request): Incoming request

        Returns:
            tuple: ``(request, err_msg)`` where ``err_msg`` is None on success or a
            human-readable error string when validation fails.
        """
        if request.guided_json:
            # Start from the raw value so the error message is meaningful even if
            # serialization itself fails.
            guided_json = request.guided_json
            try:
                if not isinstance(guided_json, str):
                    guided_json = json.dumps(guided_json)

                Grammar.from_json_schema(guided_json,
                                         any_whitespace=self.any_whitespace)
                # The unsupported-feature scan only inspects dicts, so it must
                # receive the parsed schema rather than the JSON string.
                schema = json.loads(guided_json)
            except (RuntimeError, TypeError, ValueError) as e:
                err_msg = f"Invalid JSON format: {guided_json}, error message: {str(e)}"
                return request, err_msg

            if self._unsupported_json_schema(schema):
                err_msg = f"unsupported JSON schema: {guided_json}"
                return request, err_msg

            request.guided_json = guided_json
            return request, None
        elif request.guided_grammar:
            # TODO: XGrammar only supports GBNF grammars, convert Lark to GBNF
            guided_grammar = request.guided_grammar
            try:
                Grammar.from_ebnf(guided_grammar)
            except RuntimeError as e:
                err_msg = f"Invalid grammar format: {guided_grammar}, error message: {str(e)}"
                return request, err_msg
            request.guided_grammar = guided_grammar
            return request, None
        elif request.guided_json_object:
            request.guided_json = '{"type": "object"}'
            return request, None
        elif request.guided_choice:
            try:
                # Escape double quotes and backslashes so each choice is a valid
                # EBNF string literal.
                escaped_choices = (re.sub(r'(["\\])', r'\\\1', c)
                                   for c in request.guided_choice)
                guided_choice = ('root ::= ' +
                                 ' | '.join(f'"{c}"' for c in escaped_choices))

                Grammar.from_ebnf(guided_choice)
            except RuntimeError as e:
                err_msg = f"Invalid choice format: {guided_choice}, error message: {str(e)}"
                return request, err_msg

            request.guided_grammar = guided_choice
            return request, None
        elif request.structural_tag:
            try:
                structural_tag = json.loads(request.structural_tag)
                tags = [
                    StructuralTagItem(
                        begin=s["begin"],
                        schema=json.dumps(s["schema"]),
                        end=s["end"],
                    ) for s in structural_tag["structures"]
                ]
                Grammar.from_structural_tag(tags, structural_tag["triggers"])
            except (RuntimeError, TypeError, ValueError, KeyError) as e:
                # Malformed JSON, missing keys or wrong types are reported to the
                # caller instead of escaping as an exception.
                err_msg = f"Invalid structural_tag format: {request.structural_tag}, error message: {str(e)}"
                return request, err_msg
            return request, None
        else:
            # regex (or no guided decoding) needs no formatting here
            return request, None
